{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "import tensorflow_datasets as tfds\n",
    "from t5.data import preprocessors as prep\n",
    "import functools\n",
    "import t5\n",
    "import gin\n",
    "import sentencepiece as spm\n",
    "from glob import glob\n",
    "import os\n",
    "\n",
    "# SentencePiece vocabulary file; the same path is passed to the task registration below.\n",
    "vocab = 'sp10m.cased.t5.model'\n",
    "sp = spm.SentencePieceProcessor()\n",
    "\n",
    "# Apply the gin bindings stored in the operative config file.\n",
    "gin.parse_config_file('pretrained_models_base_operative_config.gin')\n",
    "\n",
    "# Kept as the final expression so the cell displays the value returned by Load.\n",
    "sp.Load(vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['../pure-text/dumping-instagram.txt',\n",
       " '../pure-text/dumping-parliament.txt',\n",
       " '../pure-text/dumping-iium.txt',\n",
       " '../pure-text/dumping-wiki.txt',\n",
       " '../pure-text/dumping-news.txt',\n",
       " '../pure-text/dumping-watpadd.txt',\n",
       " '../pure-text/dumping-pdf.txt']"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Collect the corpus dumps, leaving out every path that mentions twitter or common.\n",
    "files = [\n",
    "    path\n",
    "    for path in glob('../pure-text/dumping*.txt')\n",
    "    if not ('twitter' in path or 'common' in path)\n",
    "]\n",
    "files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'dumping-instagram.txt'"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# File name of the first dump with its directory part removed.\n",
    "os.path.basename(files[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "def cleaning(string):\n",
    "    \"\"\"Collapse `string` onto one line so it fits in a single TSV field.\n",
    "\n",
    "    Every run of spaces, tabs, carriage returns and newlines becomes one\n",
    "    space, and surrounding whitespace is stripped. Tabs and carriage returns\n",
    "    are folded too because the rows built from this text use a tab as the\n",
    "    column separator and a newline as the row terminator; a stray tab would\n",
    "    add an extra column to the row.\n",
    "\n",
    "    Args:\n",
    "        string: text to normalise.\n",
    "\n",
    "    Returns:\n",
    "        The text on a single line with single spaces between tokens.\n",
    "    \"\"\"\n",
    "    return re.sub(r'[ \\t\\r\\n]+', ' ', string).strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "../pure-text/dumping-instagram.txt 695600\n",
      "../pure-text/dumping-parliament.txt 70046\n",
      "../pure-text/dumping-iium.txt 17837\n",
      "../pure-text/dumping-wiki.txt 326735\n",
      "../pure-text/dumping-news.txt 105469\n",
      "../pure-text/dumping-watpadd.txt 55091\n",
      "../pure-text/dumping-pdf.txt 55538\n"
     ]
    }
   ],
   "source": [
    "for file in files:\n",
    "    with open(file) as fopen:\n",
    "        data = fopen.read().split('\\n')\n",
    "\n",
    "    # Group consecutive non-empty lines into paragraphs; blank lines only\n",
    "    # separate paragraphs and never become part of one, so repeated blank\n",
    "    # lines cannot produce empty or '. '-prefixed paragraphs.\n",
    "    results, result = [], []\n",
    "    for line in data:\n",
    "        if line:\n",
    "            result.append(line)\n",
    "        elif result:\n",
    "            results.append('. '.join(result))\n",
    "            result = []\n",
    "    # Flush the last paragraph when the file does not end with a blank line.\n",
    "    if result:\n",
    "        results.append('. '.join(result))\n",
    "    print(file, len(results))\n",
    "\n",
    "    # One TSV per dump, written to the working directory: the first column\n",
    "    # is the cleaned paragraph and the second column is left empty.\n",
    "    s = os.path.split(file)[1]\n",
    "    filename = f'{s}-pair.tsv'\n",
    "\n",
    "    with tf.io.gfile.GFile(filename, 'w') as outfile:\n",
    "        for paragraph in results:\n",
    "            outfile.write('%s\\t\\n' % (cleaning(paragraph)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['dumping-news.txt-pair.tsv',\n",
       " 'dumping-iium.txt-pair.tsv',\n",
       " 'dumping-parliament.txt-pair.tsv',\n",
       " 'dumping-wiki.txt-pair.tsv',\n",
       " 'dumping-watpadd.txt-pair.tsv',\n",
       " 'dumping-pdf.txt-pair.tsv',\n",
       " 'dumping-instagram.txt-pair.tsv']"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# List the pair TSV files matching the pattern in the working directory.\n",
    "glob('dumping*pair.tsv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/husein/.local/lib/python3.6/site-packages/t5/models/mesh_transformer.py:210: UserWarning: get_sentencepiece_model_path is deprecated. Please pass the mixture or task vocabulary directly to the Mesh TensorFlow Transformer instead.\n",
      "  \"get_sentencepiece_model_path is deprecated. Please pass the mixture or \"\n"
     ]
    }
   ],
   "source": [
    "def pair_dataset(split, shuffle_files = False):\n",
    "    \"\"\"Build the raw text dataset from every `dumping*pair.tsv` file.\n",
    "\n",
    "    `split` and `shuffle_files` are accepted but not used: the same files\n",
    "    are read whatever values are passed.\n",
    "\n",
    "    Each line is parsed as two tab-separated string columns (no quote\n",
    "    handling); only the first column is kept, under the key `text`.\n",
    "    \"\"\"\n",
    "    del shuffle_files\n",
    "\n",
    "    parse_row = functools.partial(\n",
    "        tf.io.decode_csv,\n",
    "        record_defaults = ['', ''],\n",
    "        field_delim = '\\t',\n",
    "        use_quote_delim = False,\n",
    "    )\n",
    "\n",
    "    def keep_text(*columns):\n",
    "        # zip stops at the shorter sequence, so only the first column is named.\n",
    "        return dict(zip(['text'], columns))\n",
    "\n",
    "    tsv_paths = glob('dumping*pair.tsv')\n",
    "    lines = tf.data.TextLineDataset(tsv_paths)\n",
    "    rows = lines.map(\n",
    "        parse_row,\n",
    "        num_parallel_calls = tf.data.experimental.AUTOTUNE,\n",
    "    )\n",
    "    return rows.map(keep_text)\n",
    "\n",
    "\n",
    "# Remove any task already registered under this name before adding it again\n",
    "# (presumably so the cell can be re-run in the same kernel - TODO confirm).\n",
    "t5.data.TaskRegistry.remove('pair_dataset')\n",
    "# Register the task: pair_dataset supplies the raw text, the text is passed\n",
    "# through the next_sentence_prediction preprocessor, and the SentencePiece\n",
    "# model at `vocab` is used as the vocabulary. Only a 'train' split is declared.\n",
    "t5.data.TaskRegistry.add(\n",
    "    'pair_dataset',\n",
    "    dataset_fn = pair_dataset,\n",
    "    splits = ['train'],\n",
    "    text_preprocessor = [prep.next_sentence_prediction],\n",
    "    sentencepiece_model_path = vocab,\n",
    "    metric_fns = [t5.evaluation.metrics.accuracy],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Entity <function neighboring_pairs.<locals>.<lambda> at 0x7fb822019b70> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: Failed to parse source code of <function neighboring_pairs.<locals>.<lambda> at 0x7fb822019b70>, which Python reported as:\n",
      "      lambda x: x[text_key], num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
      "\n",
      "If this is a lambda function, the error may be avoided by creating the lambda in a standalone statement.\n",
      "WARNING: Entity <function neighboring_pairs.<locals>.<lambda> at 0x7fb822019b70> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: Failed to parse source code of <function neighboring_pairs.<locals>.<lambda> at 0x7fb822019b70>, which Python reported as:\n",
      "      lambda x: x[text_key], num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
      "\n",
      "If this is a lambda function, the error may be avoided by creating the lambda in a standalone statement.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "5408651it [39:48, 2411.60it/s]IOPub message rate exceeded.\n",
      "The notebook server will temporarily stop sending output\n",
      "to the client in order to avoid crashing it.\n",
      "To change this limit, set the config variable\n",
      "`--NotebookApp.iopub_msg_rate_limit`.\n",
      "\n",
      "Current values:\n",
      "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
      "NotebookApp.rate_limit_window=3.0 (secs)\n",
      "\n",
      "11483926it [1:25:38, 2234.71it/s]\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "nq_task = t5.data.TaskRegistry.get(\"pair_dataset\")\n",
    "# Ask for 'train', the only split the task was registered with.\n",
    "ds = nq_task.get_dataset(split='train', sequence_length={\"inputs\": 768, \"targets\": 768})\n",
    "\n",
    "# Decode every example back to text and write one line per example,\n",
    "# batch_size lines per pair-{part}.parse file.\n",
    "batch_size, index, part = 1000000, 0, 0\n",
    "fopen = open(f'pair-{part}.parse', 'w')\n",
    "try:\n",
    "    for ex in tqdm(tfds.as_numpy(ds)):\n",
    "        # Roll over before writing, so every part holds at most batch_size\n",
    "        # lines and no empty trailing part is created.\n",
    "        if index == batch_size:\n",
    "            fopen.close()\n",
    "            part += 1\n",
    "            index = 0\n",
    "            fopen = open(f'pair-{part}.parse', 'w')\n",
    "\n",
    "        i = sp.DecodeIds(ex['inputs'].tolist())\n",
    "        t = sp.DecodeIds(ex['targets'].tolist())\n",
    "        text = f'{i} [[EENNDD]] {t}\\n'\n",
    "        fopen.write(text)\n",
    "        index += 1\n",
    "finally:\n",
    "    # Close the current part even if iteration or decoding fails.\n",
    "    fopen.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
